import torch
import cv2 as cv
from model.zhnnet import ZhnNet, image_generate_loc


# Single-image inference demo: load the trained ZhnNet weights, run the network
# on one sample image and display the image returned by image_generate_loc.

WEIGHTS_PATH = 'zhnnet.pth'
IMAGE_PATH = 'E:/dataset/zhnfacesample.png'

model = ZhnNet()
# The model is never moved off the CPU in this script, so map the checkpoint
# onto the CPU as well; otherwise a checkpoint saved from a CUDA device cannot
# be deserialized on a machine without a GPU.
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location='cpu'))

# cv.imread returns None instead of raising when the file is missing or
# unreadable, so check explicitly. A real exception is used rather than
# `assert` because assertions are stripped when Python runs with -O.
image_origin = cv.imread(IMAGE_PATH)
if image_origin is None:
    raise FileNotFoundError(f'Image does not exist: {IMAGE_PATH}')

# Pad 16 white pixels on the top and bottom edges only (left/right unchanged).
image_origin = cv.copyMakeBorder(image_origin, 16, 16, 0, 0, cv.BORDER_CONSTANT, value=(255, 255, 255))
# Move the channel axis first (H, W, C -> C, H, W) and scale the pixel values.
# NOTE(review): dividing by 256 (not 255) presumably mirrors the preprocessing
# used during training -- confirm before changing.
image = image_origin.transpose(2, 0, 1)/256
print('Network loading complete.')

model.eval()
with torch.no_grad():
    # Add a leading batch dimension of size 1 before the forward pass and
    # squeeze size-1 dimensions out of the network output afterwards.
    image = torch.tensor(image, dtype=torch.float32).unsqueeze(0)
    predict = model(image).squeeze()

# image_generate_loc is a project helper; presumably it draws the predicted
# location(s) onto the padded image -- verify against model/zhnnet.py.
image_origin = image_generate_loc(image_origin, predict)
cv.imshow('test', image_origin)
cv.waitKey(0)  # block until a key is pressed
cv.destroyAllWindows()
